{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "GrooVAE.ipynb",
      "provenance": [],
      "collapsed_sections": [
        "CNtquLL_5VEs",
        "yPY395Uo-y9f",
        "KLwr71jntKdP",
        "ylJ8BX1cu0Cn",
        "JPJuKYs0u7f4",
        "1rylVpq0vB-3"
      ],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CNtquLL_5VEs",
        "colab_type": "text"
      },
      "source": [
        "# GrooVAE: Generating and Controlling Expressive Drum Performances\n",
        "### ___Jon Gillick, Adam Roberts, Jesse Engel___\n",
        "\n",
        "#### To open this notebook in Colab, visit https://goo.gl/magenta/groovae-colab\n",
        "\n",
        "---\n",
        "\n",
        "This notebook demonstrates some applications of machine learning for generating and manipulating beats and drum performances.  Additional details can be found in our [paper](https://goo.gl/magenta/groovae-paper) and [blog post](https://g.co/magenta/groovae).\n",
        "\n",
        "To make these experiments possible, we hired some talented professional drummers to record on an electronic drum kit (see the [Groove MIDI Dataset](https://g.co/magenta/groove-datasets) for more details), and then we trained our \"GrooVAE\" models on this data. \n",
        "\n",
        "<br>\n",
        "\n",
        "One way to think about a MIDI drum beat, whether it is played live or electronically sequenced, is to break it down into two main components: \n",
        "\n",
        "<ul>\n",
        "  <li> The Score (which drums are played, as written in Western music notation) </li> \n",
        "  <li> The Groove (how the drums are played, i.e. dynamics and timing) </li> \n",
        "\n",
        " </ul>\n",
        " \n",
        " <br>\n",
        " \n",
        "In this simplified view, a drum beat is the combination of a score and a groove.  Given either one, a good drummer knows how to fill in the other to produce a complete beat; in this project, we train models to perform this completion.\n",
        " \n",
        " \n",
        " \n",
        " \n",
        " <img src=\"https://magenta-staging.tensorflow.org/assets/groovae/score-groove.png\" alt=\"GrooVAE Figure\" >\n",
        "\n",
        "\n",
        "\n",
        "\n",
        "\n"
      ]
    },
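    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a minimal sketch of this decomposition (an illustration only, not the representation GrooVAE uses internally), assume a beat is a list of hypothetical `(pitch, onset_seconds, velocity)` hits at a fixed 120 qpm, quantized to 16th-note steps. The score keeps only which drum lands on which grid step; the groove keeps each hit's timing offset from that step (as a fraction of a step) together with its velocity. The helper names below are made up for this example.\n",
        "\n",
        "```python\n",
        "# Hypothetical representation: each hit is (pitch, onset_seconds, velocity).\n",
        "QPM = 120.0\n",
        "STEPS_PER_QUARTER = 4\n",
        "STEP_SECONDS = 60.0 / QPM / STEPS_PER_QUARTER  # 0.125 s per 16th note at 120 qpm\n",
        "\n",
        "def split_score_and_groove(hits):\n",
        "  score, groove = [], []\n",
        "  for pitch, onset, velocity in hits:\n",
        "    step = round(onset / STEP_SECONDS)  # nearest grid step\n",
        "    offset = (onset - step * STEP_SECONDS) / STEP_SECONDS  # in steps, within about +/-0.5\n",
        "    score.append((pitch, step))\n",
        "    groove.append((offset, velocity))\n",
        "  return score, groove\n",
        "\n",
        "def combine(score, groove):\n",
        "  # Inverse of split_score_and_groove: put each hit back at its step plus offset.\n",
        "  return [(pitch, (step + offset) * STEP_SECONDS, velocity)\n",
        "          for (pitch, step), (offset, velocity) in zip(score, groove)]\n",
        "\n",
        "beat = [(36, 0.01, 100), (42, 0.12, 60), (38, 0.51, 110)]\n",
        "score, groove = split_score_and_groove(beat)\n",
        "print(score)  # [(36, 0), (42, 1), (38, 4)]\n",
        "```\n",
        "\n",
        "`combine` inverts the split up to floating-point rounding, so in this toy setting the score and the groove together carry all of the information in the original hits."
      ]
    },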
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "yPY395Uo-y9f",
        "colab_type": "text"
      },
      "source": [
        "# Environment Setup"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "I47LxktGbwMF",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Setup Environment\n",
        "\n",
        "print('Installing dependencies...')\n",
        "\n",
        "!apt-get update -qq && apt-get install -qq libfluidsynth2 fluid-soundfont-gm build-essential libasound2-dev libjack-dev\n",
        "!pip install -q pyfluidsynth\n",
        "!pip install -U -q magenta\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "import tensorflow as tf\n",
        "\n",
        "# Allow python to pick up the newly-installed fluidsynth lib.\n",
        "# This is only needed for the hosted Colab environment.\n",
        "import ctypes.util\n",
        "orig_ctypes_util_find_library = ctypes.util.find_library\n",
        "def proxy_find_library(lib):\n",
        "  if lib == 'fluidsynth':\n",
        "    return 'libfluidsynth.so.2'\n",
        "  else:\n",
        "    return orig_ctypes_util_find_library(lib)\n",
        "ctypes.util.find_library = proxy_find_library\n",
        "  \n",
        "print('Importing software libraries...')\n",
        "\n",
        "import copy, warnings, librosa, numpy as np\n",
        "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n",
        "\n",
        "\n",
        "# Colab/Notebook specific stuff\n",
        "import IPython.display\n",
        "from IPython.display import Audio\n",
        "from google.colab import files\n",
        "\n",
        "# Magenta specific stuff\n",
        "from magenta.models.music_vae import configs\n",
        "from magenta.models.music_vae.trained_model import TrainedModel\n",
        "from magenta.models.music_vae import data\n",
        "import note_seq\n",
        "from note_seq import midi_synth\n",
        "from note_seq.sequences_lib import concatenate_sequences\n",
        "from note_seq.protobuf import music_pb2\n",
        "\n",
        "# Define some functions\n",
        "\n",
        "# If a sequence has notes that start before 0.0, move them to 0 (keeping their durations)\n",
        "def start_notes_at_0(s):\n",
        "  for n in s.notes:\n",
        "    if n.start_time < 0:\n",
        "      n.end_time -= n.start_time\n",
        "      n.start_time = 0\n",
        "  return s\n",
        "\n",
        "def play(note_sequence, sf2_path='Standard_Drum_Kit.sf2'):  \n",
        "  if sf2_path:\n",
        "    audio_seq = midi_synth.fluidsynth(start_notes_at_0(note_sequence), sample_rate=44100, sf2_path=sf2_path)\n",
        "    IPython.display.display(IPython.display.Audio(audio_seq, rate=44100))\n",
        "  else:\n",
        "    note_seq.play_sequence(start_notes_at_0(note_sequence), synth=note_seq.fluidsynth)\n",
        "\n",
        "# Some MIDI files put drum notes on other instrument channels by default;\n",
        "# this is a quick and dirty way to mark every note as a drum\n",
        "def set_to_drums(ns):\n",
        "  for n in ns.notes:\n",
        "    n.instrument=9\n",
        "    n.is_drum = True\n",
        "    \n",
        "def unset_to_drums(ns):\n",
        "  for note in ns.notes:\n",
        "    note.is_drum=False\n",
        "    note.instrument=0\n",
        "  return ns\n",
        "\n",
        "# quickly change the tempo of a midi sequence and adjust all notes\n",
        "def change_tempo(note_sequence, new_tempo):\n",
        "  new_sequence = copy.deepcopy(note_sequence)\n",
        "  ratio = note_sequence.tempos[0].qpm / new_tempo\n",
        "  for note in new_sequence.notes:\n",
        "    note.start_time = note.start_time * ratio\n",
        "    note.end_time = note.end_time * ratio\n",
        "  new_sequence.tempos[0].qpm = new_tempo\n",
        "  return new_sequence\n",
        "\n",
        "def download(note_sequence, filename):\n",
        "  note_seq.sequence_proto_to_midi_file(note_sequence, filename)\n",
        "  files.download(filename)\n",
        "  \n",
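        "def download_audio(audio_sequence, filename, sr):\n",
        "  # librosa.output.write_wav was removed in librosa 0.8, so write a\n",
        "  # peak-normalized WAV with soundfile (a librosa dependency) instead\n",
        "  import soundfile as sf\n",
        "  y = librosa.util.normalize(np.asarray(audio_sequence), axis=None)\n",
        "  if y.ndim == 2:  # librosa stereo is (channels, samples); soundfile wants (samples, channels)\n",
        "    y = y.T\n",
        "  sf.write(filename, y, sr)\n",
        "  files.download(filename)\n",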
        " \n",
        "# Load some configs to be used later\n",
        "dc_quantize = configs.CONFIG_MAP['groovae_2bar_humanize'].data_converter\n",
        "dc_tap = configs.CONFIG_MAP['groovae_2bar_tap_fixed_velocity'].data_converter\n",
        "dc_hihat = configs.CONFIG_MAP['groovae_2bar_add_closed_hh'].data_converter\n",
        "dc_4bar = configs.CONFIG_MAP['groovae_4bar'].data_converter\n",
        "\n",
        "# quick method for removing microtiming and velocity from a sequence\n",
        "def get_quantized_2bar(s, velocity=0):\n",
        "  new_s = dc_quantize.from_tensors(dc_quantize.to_tensors(s).inputs)[0]\n",
        "  new_s = change_tempo(new_s, s.tempos[0].qpm)\n",
        "  if velocity != 0:\n",
        "    for n in new_s.notes:\n",
        "      n.velocity = velocity\n",
        "  return new_s\n",
        "\n",
        "# quick method for turning a drumbeat into a tapped rhythm\n",
        "# (ride=True moves every tap to pitch 42, the General MIDI closed hi-hat)\n",
        "def get_tapped_2bar(s, velocity=85, ride=False):\n",
        "  new_s = dc_tap.from_tensors(dc_tap.to_tensors(s).inputs)[0]\n",
        "  new_s = change_tempo(new_s, s.tempos[0].qpm)\n",
        "  if velocity != 0:\n",
        "    for n in new_s.notes:\n",
        "      n.velocity = velocity\n",
        "  if ride:\n",
        "    for n in new_s.notes:\n",
        "      n.pitch = 42\n",
        "  return new_s\n",
        "\n",
        "# quick method for removing hi-hats from a sequence\n",
        "def get_hh_2bar(s):\n",
        "  new_s = dc_hihat.from_tensors(dc_hihat.to_tensors(s).inputs)[0]\n",
        "  new_s = change_tempo(new_s, s.tempos[0].qpm)\n",
        "  return new_s\n",
        "\n",
        "\n",
        "# Calculate quantization steps but do not remove microtiming\n",
        "def quantize(s, steps_per_quarter=4):\n",
        "  return note_seq.sequences_lib.quantize_note_sequence(s,steps_per_quarter)\n",
        "\n",
        "# Destructively quantize a midi sequence\n",
        "def flatten_quantization(s):\n",
        "  beat_length = 60. / s.tempos[0].qpm\n",
        "  step_length = beat_length / 4  # assumes 4 steps per quarter (s.quantization_info.steps_per_quarter)\n",
        "  new_s = copy.deepcopy(s)\n",
        "  for note in new_s.notes:\n",
        "    note.start_time = step_length * note.quantized_start_step\n",
        "    note.end_time = step_length * note.quantized_end_step\n",
        "  return new_s\n",
        "\n",
        "# Calculate how far off the beat a note is\n",
        "def get_offset(s, note_index):\n",
        "  q_s = flatten_quantization(quantize(s))\n",
        "  true_onset = s.notes[note_index].start_time\n",
        "  quantized_onset = q_s.notes[note_index].start_time\n",
        "  diff = quantized_onset - true_onset\n",
        "  beat_length = 60. / s.tempos[0].qpm\n",
        "  step_length = beat_length / 4  # assumes 4 steps per quarter (q_s.quantization_info.steps_per_quarter)\n",
        "  offset = diff/step_length\n",
        "  return offset\n",
        "\n",
        "def is_4_4(s):\n",
        "  ts = s.time_signatures[0]\n",
        "  return (ts.numerator == 4 and ts.denominator ==4)\n",
        "\n",
        "def preprocess_4bar(s):\n",
        "  return dc_4bar.from_tensors(dc_4bar.to_tensors(s).outputs)[0]\n",
        "\n",
        "def preprocess_2bar(s):\n",
        "  return dc_quantize.from_tensors(dc_quantize.to_tensors(s).outputs)[0]\n",
        "\n",
        "def _slerp(p0, p1, t):\n",
        "  \"\"\"Spherical linear interpolation.\"\"\"\n",
        "  omega = np.arccos(np.dot(np.squeeze(p0/np.linalg.norm(p0)),\n",
        "    np.squeeze(p1/np.linalg.norm(p1))))\n",
        "  so = np.sin(omega)\n",
        "  return np.sin((1.0-t)*omega) / so * p0 + np.sin(t*omega)/so * p1\n",
        "\n",
        "print('Downloading drum samples...')\n",
        "# Download a drum kit for playing drum midi\n",
        "!gsutil -q -m cp gs://magentadata/soundfonts/Standard_Drum_Kit.sf2 .\n",
        "\n",
        "print(\"Downloading MIDI data...\")\n",
        "\n",
        "# Load the MIDI-only (no audio) validation split of the Groove MIDI Dataset (GMD) as NumPy data\n",
        "dataset_2bar = tfds.as_numpy(tfds.load(\n",
        "    name=\"groove/2bar-midionly\",\n",
        "    split=tfds.Split.VALIDATION,\n",
        "    try_gcs=True))\n",
        "\n",
        "dev_sequences = [quantize(note_seq.midi_to_note_sequence(features[\"midi\"])) for features in dataset_2bar]\n",
        "_ = [set_to_drums(s) for s in dev_sequences]\n",
        "dev_sequences = [s for s in dev_sequences if is_4_4(s) and len(s.notes) > 0 and s.notes[-1].quantized_end_step > note_seq.steps_per_bar_in_quantized_sequence(s)]\n",
        "\n",
        "dataset_4bar = tfds.as_numpy(tfds.load(\n",
        "    name=\"groove/4bar-midionly\",\n",
        "    split=tfds.Split.VALIDATION,\n",
        "    try_gcs=True))\n",
        "\n",
        "dev_sequences_4bar = [quantize(note_seq.midi_to_note_sequence(features[\"midi\"])) for features in dataset_4bar]\n",
        "_ = [set_to_drums(s) for s in dev_sequences_4bar]\n",
        "dev_sequences_4bar = [s for s in dev_sequences_4bar if is_4_4(s) and len(s.notes) > 0 and s.notes[-1].quantized_end_step > note_seq.steps_per_bar_in_quantized_sequence(s)]\n",
        "\n",
        "\n",
        "print(\"Loading model checkpoints...\")\n",
        "\n",
        "# Download all the models\n",
        "!gsutil -q -m cp gs://magentadata/models/music_vae/checkpoints/groovae_*.tar .\n",
        "GROOVAE_4BAR = \"groovae_4bar.tar\"\n",
        "GROOVAE_2BAR_HUMANIZE = \"groovae_2bar_humanize.tar\"\n",
        "GROOVAE_2BAR_HUMANIZE_NOKL = \"groovae_2bar_humanize_nokl.tar\"\n",
        "GROOVAE_2BAR_HITS_CONTROL = \"groovae_2bar_hits_control.tar\"\n",
        "GROOVAE_2BAR_TAP_FIXED_VELOCITY = \"groovae_2bar_tap_fixed_velocity.tar\"\n",
        "GROOVAE_2BAR_ADD_CLOSED_HH = \"groovae_2bar_add_closed_hh.tar\"\n",
        "GROOVAE_2BAR_HITS_CONTROL_NOKL = \"groovae_2bar_hits_control_nokl.tar\"\n",
        "\n",
        "print(\"Downloading audio data...\")\n",
        "!gsutil -q -m cp gs://magentadata/models/music_vae/groovae_colab/*wav ."
      ],
      "execution_count": null,
      "outputs": []
    },
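    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The timing helpers in the setup cell reduce to simple arithmetic. With 4 steps per quarter note (the default used by `quantize`), one step lasts `60 / qpm / 4` seconds; `get_offset` returns `(quantized_onset - true_onset) / step_length`, so a note played after the grid gets a negative offset; and `change_tempo` multiplies every time by `old_qpm / new_qpm`. The pure-Python sketch below mirrors that arithmetic on plain numbers instead of NoteSequences; the function names are hypothetical.\n",
        "\n",
        "```python\n",
        "def step_length(qpm, steps_per_quarter=4):\n",
        "  # Duration of one quantization step in seconds.\n",
        "  return 60.0 / qpm / steps_per_quarter\n",
        "\n",
        "def offset_in_steps(true_onset, quantized_onset, qpm):\n",
        "  # Same sign convention as get_offset: late notes give negative values.\n",
        "  return (quantized_onset - true_onset) / step_length(qpm)\n",
        "\n",
        "def rescale_time(t, old_qpm, new_qpm):\n",
        "  # Same ratio as change_tempo: a faster tempo shrinks every time value.\n",
        "  return t * old_qpm / new_qpm\n",
        "\n",
        "# At 120 qpm a step is 0.125 s, so a note 25 ms late sits about 0.2 steps behind the grid.\n",
        "print(offset_in_steps(0.525, 0.5, 120))  # approximately -0.2\n",
        "print(rescale_time(2.0, 120, 60))  # 4.0\n",
        "```"
      ]
    },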
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KLwr71jntKdP",
        "colab_type": "text"
      },
      "source": [
        "# Generate New Beats"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UELPaZi1CrTg",
        "colab_type": "text"
      },
      "source": [
        "Before we get more specific, let's generate some beats from scratch.  One powerful ability of Variational Autoencoder models is generating new data points similar to the ones they were trained on. As with [MusicVAE](https://g.co/magenta/music-vae), we can sample as many new beats from our latent space as we like, but with GrooVAE, the latent space encodes not just the drum pattern but also the performance characteristics of the drummers who played it.  We can also interpolate smoothly between different beats in this latent space."
      ]
    },
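    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "Interpolating between two beats means walking between their latent codes. The setup cell defines `_slerp` for spherical linear interpolation; the sketch below restates it in NumPy, with an added `np.clip` guard in case rounding pushes the cosine outside [-1, 1], and uses random Gaussian vectors as stand-ins for two encoded beats (the 256-dimensional size is an assumption for illustration only). For two codes of similar length, slerp keeps the intermediate codes at roughly that length, while the straight-line midpoint of two nearly orthogonal high-dimensional vectors is shorter.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def slerp(p0, p1, t):\n",
        "  # Spherical linear interpolation, as in the _slerp helper from the setup cell.\n",
        "  cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))\n",
        "  omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))\n",
        "  so = np.sin(omega)\n",
        "  return np.sin((1.0 - t) * omega) / so * p0 + np.sin(t * omega) / so * p1\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "z_a, z_b = rng.standard_normal(256), rng.standard_normal(256)  # stand-ins for two encoded beats\n",
        "path = [slerp(z_a, z_b, t) for t in np.linspace(0.0, 1.0, 5)]\n",
        "print([round(float(np.linalg.norm(z)), 1) for z in path])  # norms along the slerp path\n",
        "print(round(float(np.linalg.norm((z_a + z_b) / 2)), 1))  # straight-line midpoint, noticeably shorter\n",
        "```"
      ]
    },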
    {
      "cell_type": "code",
      "metadata": {
        "id": "KWTs1HYXtM_d",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Load checkpoint\n",
        "\n",
        "config_4_bar = configs.CONFIG_MAP['groovae_4bar']\n",
        "groovae_4_bar = TrainedModel(config_4_bar, 2, checkpoint_dir_or_path=GROOVAE_4BAR)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jlh97al4ADyU",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Generate Beats\n",
        "temperature = 1. #@param {type:\"slider\", min:0.01, max:2.0, step:0.01}\n",
        "tempo = 116 #@param {type:\"slider\", min:80, max:180, step:1}\n",
        "samples = groovae_4_bar.sample(3,temperature=temperature,length=64)\n",
        "samples = [change_tempo(start_notes_at_0(s),tempo) for s in samples]\n",
        "for s in samples:\n",
        "  play(s)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f1tXIwh9bC5E",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Interpolate Between Beats\n",
        "temperature = 1 #@param {type:\"slider\", min:0.01, max:2.0, step:0.01}\n",
        "steps = 3 #@param {type:\"slider\", min:1, max:5, step:1}\n",
        "\n",
        "sequence_indices = np.random.randint(0,len(dev_sequences_4bar), 2)\n",
        "beat_a = change_tempo(dev_sequences_4bar[sequence_indices[0]], 120)\n",
        "beat_a = preprocess_4bar(beat_a)\n",
        "beat_b = change_tempo(dev_sequences_4bar[sequence_indices[1]], 120)\n",
        "beat_b = preprocess_4bar(beat_b)\n",
        "\n",
        "print(\"Playing Beat A\")\n",
        "play(beat_a)\n",
        "print(\"Playing Beat B\")\n",
        "play(beat_b)\n",
        "\n",
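        "# steps + 2 sequences: reconstructions of beats A and B plus `steps` beats in between\n",
        "seqs = groovae_4_bar.interpolate(beat_a, beat_b, steps + 2, length=64, temperature=temperature)\n",
        "\n",
        "# each 4-bar sequence at 120 qpm lasts 8 seconds\n",
        "individual_duration = 8.0\n",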
        "\n",
        "interp_seq = concatenate_sequences(seqs, [individual_duration] * len(seqs))\n",
        "\n",
        "print(\"Playing Interpolation from A to B\")\n",
        "\n",
        "play(start_notes_at_0(interp_seq))\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ylJ8BX1cu0Cn",
        "colab_type": "text"
      },
      "source": [
        "# Groove: Add some groove to a programmed beat"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mLjJvWp_LRX4",
        "colab_type": "text"
      },
      "source": [
        "Now let's hear what it sounds like to add groove to a quantized beat.  Humanizing drums in this way is a common technique in music production for giving them more character. Traditionally, it has been done either by randomizing note timings and velocities or by setting all timings and velocities to fixed values defined by a \"swing\" setting or a template.  Here, we instead let the model predict what the groove should be, adapting the timing and velocities to the beat itself."
      ]
    },
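    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "For comparison, here is a minimal sketch of the two traditional approaches described above, applied to hypothetical `(onset_seconds, velocity)` hits at 120 qpm: random jitter, and a fixed swing template that delays every off-beat 16th note. The jitter ranges and the swing amount are arbitrary assumptions, and the function names are made up for this example.\n",
        "\n",
        "```python\n",
        "import random\n",
        "\n",
        "STEP = 0.125  # seconds per 16th note at 120 qpm\n",
        "\n",
        "def random_humanize(hits, max_shift=0.01, max_vel_change=10, seed=0):\n",
        "  # Jitter every onset and velocity independently, ignoring musical context.\n",
        "  rng = random.Random(seed)\n",
        "  out = []\n",
        "  for onset, vel in hits:\n",
        "    new_onset = max(0.0, onset + rng.uniform(-max_shift, max_shift))\n",
        "    new_vel = min(127, max(1, vel + rng.randint(-max_vel_change, max_vel_change)))\n",
        "    out.append((new_onset, new_vel))\n",
        "  return out\n",
        "\n",
        "def swing(hits, amount=0.33):\n",
        "  # Delay hits on odd-numbered 16th steps by a fixed fraction of a step.\n",
        "  out = []\n",
        "  for onset, vel in hits:\n",
        "    step = round(onset / STEP)\n",
        "    if step % 2 == 1:\n",
        "      onset = step * STEP + amount * STEP\n",
        "    out.append((onset, vel))\n",
        "  return out\n",
        "\n",
        "grid = [(i * STEP, 80) for i in range(8)]\n",
        "print(swing(grid)[:3])  # the off-beat 16th lands a third of a step late\n",
        "print(random_humanize(grid)[:3])\n",
        "```\n",
        "\n",
        "Both rules treat every beat the same way, which is the limitation that a learned, beat-dependent groove is meant to address."
      ]
    },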
    {
      "cell_type": "code",
      "metadata": {
        "id": "b_O6QE8Q35fi",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Load checkpoint\n",
        "\n",
        "config_2_bar_humanize = configs.CONFIG_MAP['groovae_2bar_humanize']\n",
        "groovae_2_bar_humanize = TrainedModel(config_2_bar_humanize, 1, checkpoint_dir_or_path=GROOVAE_2BAR_HUMANIZE)\n",
        "\n",
        "def humanize(s, model, temperature=1.0):  \n",
        "  encoding, mu, sigma = model.encode([s])\n",
        "  decoded = model.decode(encoding, length=32, temperature=temperature)[0]\n",
        "  return change_tempo(decoded, s.tempos[0].qpm)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "09ypAIX15DrX",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Groove up the Beats\n",
        "\n",
        "sequence_indices = np.random.randint(0,len(dev_sequences), 3) \n",
        "for i in sequence_indices:\n",
        "  s = start_notes_at_0(dev_sequences[i])\n",
        "  s = get_quantized_2bar(s, velocity=85)\n",
        "  print(\"\\nPlaying programmed beat: \")\n",
        "  play(s)\n",
        "  h = humanize(s, groovae_2_bar_humanize)\n",
        "  print(\"Playing humanized beat: \")\n",
        "  play(start_notes_at_0(h))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "R5zfOIpeNlvv",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Upload your own quantized MIDI beats to Groove on\n",
        "\n",
        "uploaded = files.upload()\n",
        "\n",
        "uploaded_sequences = [note_seq.midi_file_to_note_sequence(fn) for fn in uploaded.keys()]  \n",
        "\n",
        "new_beats = []\n",
        "\n",
        "for s in uploaded_sequences:\n",
        "  set_to_drums(s)\n",
        "  s = start_notes_at_0(s)\n",
        "  #s = get_quantized_2bar(s, velocity=85)\n",
        "  print(\"\\nPlaying your beat: \")\n",
        "  play(s)\n",
        "  h = humanize(s, groovae_2_bar_humanize)\n",
        "  print(\"Playing humanized beat: \")\n",
        "  play(start_notes_at_0(h))\n",
        "  new_beats.append(h)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oAPK3eXuN7Ho",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Save your GrooVAE beats\n",
        "for i, beat in enumerate(new_beats):\n",
        "  download(beat, 'humanized_beat_%d.mid' %(i))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JPJuKYs0u7f4",
        "colab_type": "text"
      },
      "source": [
        "# Tap2Drum: Generate a beat from any rhythm "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GxkHBd_uGFNP",
        "colab_type": "text"
      },
      "source": [
        "While the Groove model works by removing the microtiming and velocity information and learning to predict them from just the drum pattern, we can also go in the opposite direction.  Here, we take a representation of a groove as input (a rhythm with precise timing but no drum categories) and generate drum beats that match the groove implied by that rhythm.  We trained this model by collapsing all drum hits from each beat in the training data into a single \"tapped\" rhythm, and then learning to decode full beats from that rhythm.  This lets us input any rhythm we like through the precise onset timings of a \"tap\" and have the model decode it into a beat. We can even record taps as audio, or extract them from a recording of another instrument, instead of using a MIDI controller."
      ]
    },
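    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The collapsing step can be sketched in a few lines. This is an illustration under stated assumptions, not the data converter Magenta actually uses: hits are hypothetical `(pitch, onset_seconds, velocity)` tuples, hits closer together than a small tolerance merge into one tap, and every tap gets the same pitch (42, the General MIDI closed hi-hat, which `make_tap_sequence` also uses) and a fixed velocity (85, the default in `get_tapped_2bar`).\n",
        "\n",
        "```python\n",
        "def collapse_to_taps(hits, tap_pitch=42, tap_velocity=85, tolerance=0.005):\n",
        "  # Keep only onset times: drum identity and dynamics are discarded.\n",
        "  onsets = sorted(onset for _, onset, _ in hits)\n",
        "  taps = []\n",
        "  for onset in onsets:\n",
        "    # Near-simultaneous hits (e.g. kick plus hi-hat) become a single tap.\n",
        "    if not taps or onset - taps[-1][1] > tolerance:\n",
        "      taps.append((tap_pitch, onset, tap_velocity))\n",
        "  return taps\n",
        "\n",
        "beat = [(36, 0.000, 110), (42, 0.002, 70), (42, 0.251, 65), (38, 0.498, 120), (42, 0.501, 60)]\n",
        "print(collapse_to_taps(beat))\n",
        "# [(42, 0.0, 85), (42, 0.251, 85), (42, 0.498, 85)]\n",
        "```"
      ]
    },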
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "id": "au0HjIkZuBW4",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Load checkpoint\n",
        "\n",
        "config_2bar_tap = configs.CONFIG_MAP['groovae_2bar_tap_fixed_velocity']\n",
        "groovae_2bar_tap = TrainedModel(config_2bar_tap, 1, checkpoint_dir_or_path=GROOVAE_2BAR_TAP_FIXED_VELOCITY)\n",
        "\n",
        "def mix_tracks(y1, y2, stereo = False):\n",
        "  l = max(len(y1),len(y2))\n",
        "  y1 = librosa.util.fix_length(y1, size=l)\n",
        "  y2 = librosa.util.fix_length(y2, size=l)\n",
        "  \n",
        "  if stereo:\n",
        "    return np.vstack([y1, y2])  \n",
        "  else:\n",
        "    return y1+y2\n",
        "\n",
        "def make_click_track(s):\n",
        "  last_note_time = max([n.start_time for n in s.notes])\n",
        "  beat_length = 60. / s.tempos[0].qpm \n",
        "  i = 0\n",
        "  times = []\n",
        "  while i*beat_length < last_note_time:\n",
        "    times.append(i*beat_length)\n",
        "    i += 1\n",
        "  return librosa.clicks(times)\n",
        "\n",
        "def drumify(s, model, temperature=1.0): \n",
        "  encoding, mu, sigma = model.encode([s])\n",
        "  decoded = model.decode(encoding, length=32, temperature=temperature)\n",
        "  return decoded[0]\n",
        "\n",
        "def combine_sequences(seqs):\n",
        "  # assumes a list of 2 bar seqs with constant tempo\n",
        "  for i, seq in enumerate(seqs):\n",
        "    shift_amount = i*(60 / seqs[0].tempos[0].qpm * 4 * 2)\n",
        "    if shift_amount > 0:\n",
        "      seqs[i] = note_seq.sequences_lib.shift_sequence_times(seq, shift_amount)\n",
        "  return note_seq.sequences_lib.concatenate_sequences(seqs)\n",
        "\n",
        "def combine_sequences_with_lengths(sequences, lengths):\n",
        "  seqs = copy.deepcopy(sequences)\n",
        "  total_shift_amount = 0\n",
        "  for i, seq in enumerate(seqs):\n",
        "    if i == 0:\n",
        "      shift_amount = 0\n",
        "    else:\n",
        "      shift_amount = lengths[i-1]\n",
        "    total_shift_amount += shift_amount\n",
        "    if total_shift_amount > 0:\n",
        "      seqs[i] = note_seq.sequences_lib.shift_sequence_times(seq, total_shift_amount)\n",
        "  combined_seq = music_pb2.NoteSequence()\n",
        "  for i in range(len(seqs)):\n",
        "    tempo = combined_seq.tempos.add()\n",
        "    tempo.qpm = seqs[i].tempos[0].qpm\n",
        "    tempo.time = sum(lengths[:i])  # sequence i starts after the i preceding sections\n",
        "    for note in seqs[i].notes:\n",
        "      combined_seq.notes.extend([copy.deepcopy(note)])\n",
        "  return combined_seq\n",
        "\n",
        "def get_audio_start_time(y, sr):\n",
        "  # time (in seconds) of the first detected onset\n",
        "  onset_times = librosa.onset.onset_detect(y=y, sr=sr, units='time')\n",
        "  start_time = onset_times[0]\n",
        "  return start_time\n",
        "\n",
        "def audio_tap_to_note_sequence(f, velocity_threshold=30):\n",
        "  y, sr = librosa.load(f)\n",
        "  # pad the beginning to avoid errors with onsets right at the start\n",
        "  y = np.concatenate([np.zeros(1000),y])\n",
        "  # try to guess a reasonable tempo\n",
        "  tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr)\n",
        "  onset_frames = librosa.onset.onset_detect(y=y, sr=sr, units='frames')\n",
        "  onset_times = librosa.onset.onset_detect(y=y, sr=sr, units='time')\n",
        "  start_time = onset_times[0]\n",
        "  onset_strengths = librosa.onset.onset_strength(y=y, sr=sr)[onset_frames]\n",
        "  normalized_onset_strengths = onset_strengths / np.max(onset_strengths)\n",
        "  onset_velocities = np.int32(normalized_onset_strengths * 127)\n",
        "  note_sequence = music_pb2.NoteSequence()\n",
        "  note_sequence.tempos.add(qpm=tempo)\n",
        "  for onset_vel, onset_time in zip(onset_velocities, onset_times):\n",
        "    if onset_vel > velocity_threshold and onset_time >= start_time:  # filter quietest notes\n",
        "      note_sequence.notes.add(\n",
        "        instrument=9, pitch=42, is_drum=True,\n",
        "        velocity=onset_vel,  # the tap model is trained with fixed input velocities\n",
        "        start_time=onset_time - start_time,\n",
        "        end_time=onset_time - start_time + 0.01)\n",
        "\n",
        "  return note_sequence\n",
        "\n",
        "# Make a short sequence encodable (otherwise the data converter extracts nothing from it)\n",
        "# by adding a silent (velocity 0) note just before the end of num_bars bars\n",
        "def add_silent_note(note_sequence, num_bars):\n",
        "  tempo = note_sequence.tempos[0].qpm\n",
        "  length = 60/tempo * 4 * num_bars\n",
        "  note_sequence.notes.add(\n",
        "    instrument=9, pitch=42, velocity=0, start_time=length-0.02, \n",
        "    end_time=length-0.01, is_drum=True)\n",
        "  \n",
        "def get_bar_length(note_sequence):\n",
        "  tempo = note_sequence.tempos[0].qpm\n",
        "  return 60/tempo * 4\n",
        "\n",
        "def sequence_is_shorter_than_full(note_sequence):\n",
        "  return note_sequence.notes[-1].start_time < get_bar_length(note_sequence)\n",
        "\n",
        "def get_rhythm_elements(y, sr):\n",
        "  onset_env = librosa.onset.onset_strength(y=y, sr=sr)\n",
        "  tempo = librosa.beat.tempo(onset_envelope=onset_env, max_tempo=180)[0]\n",
        "  onset_times = librosa.onset.onset_detect(y=y, sr=sr, units='time')\n",
        "  onset_frames = librosa.onset.onset_detect(y=y, sr=sr, units='frames')\n",
        "  onset_strengths = onset_env[onset_frames]\n",
        "  normalized_onset_strengths = onset_strengths / np.max(onset_strengths)\n",
        "  onset_velocities = np.int32(normalized_onset_strengths * 127)\n",
        "\n",
        "  return tempo, onset_times, onset_frames, onset_velocities\n",
        "\n",
        "def make_tap_sequence(tempo, onset_times, onset_frames, onset_velocities,\n",
        "                     velocity_threshold, start_time, end_time):\n",
        "  note_sequence = music_pb2.NoteSequence()\n",
        "  note_sequence.tempos.add(qpm=tempo)\n",
        "  for onset_vel, onset_time in zip(onset_velocities, onset_times):\n",
        "    if onset_vel > velocity_threshold and onset_time >= start_time and onset_time < end_time:  # filter quietest notes\n",
        "      note_sequence.notes.add(\n",
        "        instrument=9, pitch=42, is_drum=True,\n",
        "        velocity=onset_vel,  # model will use fixed velocity here\n",
        "        start_time=onset_time - start_time,\n",
        "        end_time=onset_time -start_time + 0.01\n",
        "      )\n",
        "  return note_sequence\n",
        "\n",
        "def audio_to_drum(f, velocity_threshold=30, temperature=1., force_sync=False, start_windows_on_downbeat=False):\n",
        "  y, sr = librosa.load(f)\n",
        "  # pad the beginning to avoid errors with onsets right at the start\n",
        "  y = np.concatenate([np.zeros(1000),y])\n",
        "\n",
        "  clip_length = float(len(y))/sr\n",
        "\n",
        "  tap_sequences = []\n",
        "  # Loop through the file, grabbing 2-bar sections at a time, estimating\n",
        "  # tempos along the way to try to handle tempo variations\n",
        "\n",
        "  tempo, onset_times, onset_frames, onset_velocities = get_rhythm_elements(y, sr)\n",
        "\n",
        "  initial_start_time = onset_times[0]\n",
        "\n",
        "  start_time = onset_times[0]\n",
        "  beat_length = 60/tempo\n",
        "  two_bar_length = beat_length * 8\n",
        "  end_time = start_time + two_bar_length\n",
        "\n",
        "  start_times = []\n",
        "  lengths = []\n",
        "  tempos = []\n",
        "\n",
        "  start_times.append(start_time)\n",
        "  lengths.append(end_time-start_time)\n",
        "  tempos.append(tempo)\n",
        "\n",
        "  tap_sequences.append(make_tap_sequence(tempo, onset_times, onset_frames, \n",
        "                       onset_velocities, velocity_threshold, start_time, end_time))\n",
        "\n",
        "  start_time += two_bar_length; end_time += two_bar_length\n",
        "\n",
        "\n",
        "  while start_time < clip_length:\n",
        "    start_sample = int(librosa.core.time_to_samples(start_time, sr=sr))\n",
        "    end_sample = int(librosa.core.time_to_samples(start_time + two_bar_length, sr=sr))\n",
        "    current_section = y[start_sample:end_sample]\n",
        "    tempo = librosa.beat.tempo(onset_envelope=librosa.onset.onset_strength(y=current_section, sr=sr), max_tempo=180)[0]\n",
        "\n",
        "    beat_length = 60/tempo\n",
        "    two_bar_length = beat_length * 8\n",
        "\n",
        "    end_time = start_time + two_bar_length\n",
        "\n",
        "    start_times.append(start_time)\n",
        "    lengths.append(end_time-start_time)\n",
        "    tempos.append(tempo)\n",
        "\n",
        "    tap_sequences.append(make_tap_sequence(tempo, onset_times, onset_frames, \n",
        "                         onset_velocities, velocity_threshold, start_time, end_time))\n",
        "\n",
        "    start_time += two_bar_length; end_time += two_bar_length\n",
        "  \n",
        "  # if there's a long gap before the first note, back it up close to 0\n",
        "  def _shift_notes_to_beginning(s):\n",
        "    start_time = s.notes[0].start_time\n",
        "    if start_time > 0.1:\n",
        "      for n in s.notes:\n",
        "        n.start_time -= start_time\n",
        "        n.end_time -=start_time\n",
        "    return start_time\n",
        "      \n",
        "  def _shift_notes_later(s, start_time):\n",
        "    for n in s.notes:\n",
        "      n.start_time += start_time\n",
        "      n.end_time +=start_time    \n",
        "  \n",
        "  def _sync_notes_with_onsets(s, onset_times):\n",
        "    for n in s.notes:\n",
        "      n_length = n.end_time - n.start_time\n",
        "      closest_onset_index = np.argmin(np.abs(n.start_time - onset_times))\n",
        "      n.start_time = onset_times[closest_onset_index]\n",
        "      n.end_time = n.start_time + n_length\n",
        "  \n",
        "  drum_seqs = []\n",
        "  for s in tap_sequences:\n",
        "    try:\n",
        "      if sequence_is_shorter_than_full(s):\n",
        "        add_silent_note(s, 2)\n",
        "        \n",
        "      if start_windows_on_downbeat:\n",
        "        note_start_time = _shift_notes_to_beginning(s)\n",
        "      h = drumify(s, groovae_2bar_tap, temperature=temperature)\n",
        "      h = change_tempo(h, s.tempos[0].qpm)\n",
        "      \n",
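        "      if start_windows_on_downbeat and note_start_time > 0.1:\n",
        "        # restore the tap window's original timing and move the generated drums with it\n",
        "        _shift_notes_later(s, note_start_time)\n",
        "        _shift_notes_later(h, note_start_time)\n",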
        "        \n",
        "      drum_seqs.append(h)\n",
        "    except Exception:\n",
        "      # skip windows the model cannot process (e.g. an empty tap sequence)\n",
        "      continue\n",
        "      \n",
        "  combined_tap_sequence = start_notes_at_0(combine_sequences_with_lengths(tap_sequences, lengths))\n",
        "  combined_drum_sequence = start_notes_at_0(combine_sequences_with_lengths(drum_seqs, lengths))\n",
        "  \n",
        "  if force_sync:\n",
        "    _sync_notes_with_onsets(combined_tap_sequence, onset_times)\n",
        "    _sync_notes_with_onsets(combined_drum_sequence, onset_times)\n",
        "  \n",
        "  full_tap_audio = librosa.util.normalize(midi_synth.fluidsynth(combined_tap_sequence, sample_rate=sr))\n",
        "  full_drum_audio = librosa.util.normalize(midi_synth.fluidsynth(combined_drum_sequence, sample_rate=sr))\n",
        "  \n",
        "  tap_and_onsets = mix_tracks(full_tap_audio, y[int(initial_start_time*sr):]/2, stereo=True)\n",
        "  drums_and_original = mix_tracks(full_drum_audio, y[int(initial_start_time*sr):]/2, stereo=True)\n",
        "  \n",
        "  return full_drum_audio, full_tap_audio, tap_and_onsets, drums_and_original, combined_drum_sequence"
      ],
      "execution_count": null,
      "outputs": []
    },
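    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "When `force_sync` is set, `_sync_notes_with_onsets` moves the start of every note to the nearest detected onset and keeps the note's duration. The cell below is a minimal NumPy sketch of that snapping step on plain arrays of times in seconds. `snap_to_onsets` is a hypothetical helper written for illustration; it is not part of Magenta or `note_seq`. Computing every note/onset distance at once with broadcasting replaces the per-note loop; for very long clips, a sorted search with `np.searchsorted` would use less memory."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "def snap_to_onsets(starts, ends, onset_times):\n",
        "  # Hypothetical helper mirroring _sync_notes_with_onsets on plain arrays.\n",
        "  starts = np.asarray(starts, dtype=float)\n",
        "  ends = np.asarray(ends, dtype=float)\n",
        "  onset_times = np.asarray(onset_times, dtype=float)\n",
        "  durations = ends - starts  # keep each note's original length\n",
        "  # Distance from every note start to every onset: shape (num_notes, num_onsets).\n",
        "  distances = np.abs(starts[:, None] - onset_times[None, :])\n",
        "  snapped_starts = onset_times[np.argmin(distances, axis=1)]\n",
        "  return snapped_starts, snapped_starts + durations\n",
        "\n",
        "new_starts, new_ends = snap_to_onsets([0.02, 0.49, 1.1], [0.12, 0.59, 1.2], [0.0, 0.5, 1.0])\n",
        "print(new_starts, new_ends)"
      ],
      "execution_count": null,
      "outputs": []
    },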
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "RtEPAgNISX5g",
        "colab_type": "text"
      },
      "source": [
        "Here are a couple of examples using MIDI rhythms:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jFbxCIHRVSDl",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title MIDI Taps --> Beats\n",
        "\n",
        "sequence_indices = [1111, 366]\n",
        "for i in sequence_indices:\n",
        "  s = start_notes_at_0(dev_sequences[i])\n",
        "  s = change_tempo(get_tapped_2bar(s, velocity=85, ride=True), dev_sequences[i].tempos[0].qpm)\n",
        "  print(\"\\nPlaying Tapped Beat: \")\n",
        "  play(start_notes_at_0(s))\n",
        "  h = change_tempo(drumify(s, groovae_2bar_tap), s.tempos[0].qpm)\n",
        "  print(\"Playing Drummed Beat: \")\n",
        "  play(start_notes_at_0(h))"
      ],
      "execution_count": null,
      "outputs": []
    },
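    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The tapped input above is produced by `get_tapped_2bar`, which reduces a full drum beat to a single voice (played on the ride here). As a rough illustration only, the cell below collapses a multi-drum pattern on a 16th-note grid into one tap line: a tap wherever any drum plays, using the timing offset (and, optionally, the velocity) of the loudest hit at that step. The matrix layout (velocities in [0, 1], offsets in units of one grid step) and the `collapse_to_taps` helper are assumptions made for this sketch and may not match what `get_tapped_2bar` actually does."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "def collapse_to_taps(hits, velocities, offsets, tap_velocity=None):\n",
        "  # Hypothetical helper. hits, velocities and offsets are assumed to be\n",
        "  # (num_steps, num_drums) arrays on a 16th-note grid.\n",
        "  hits = np.asarray(hits, dtype=bool)\n",
        "  velocities = np.where(hits, velocities, 0.0)  # ignore drums that do not play\n",
        "  offsets = np.asarray(offsets, dtype=float)\n",
        "  tap_hits = hits.any(axis=1)  # one tap wherever any drum plays\n",
        "  loudest = velocities.argmax(axis=1)  # loudest drum at each step\n",
        "  steps = np.arange(hits.shape[0])\n",
        "  tap_offsets = np.where(tap_hits, offsets[steps, loudest], 0.0)\n",
        "  if tap_velocity is None:\n",
        "    tap_velocities = velocities[steps, loudest]  # keep the loudest velocity\n",
        "  else:\n",
        "    tap_velocities = np.where(tap_hits, tap_velocity, 0.0)  # one fixed velocity\n",
        "  return tap_hits, tap_velocities, tap_offsets\n",
        "\n",
        "example_hits = [[1, 0], [0, 0], [1, 1], [0, 1]]\n",
        "example_velocities = [[0.8, 0.0], [0.0, 0.0], [0.3, 0.9], [0.0, 0.5]]\n",
        "example_offsets = [[0.1, 0.0], [0.0, 0.0], [-0.2, 0.05], [0.0, 0.3]]\n",
        "example_tap_hits, example_tap_velocities, example_tap_offsets = collapse_to_taps(\n",
        "    example_hits, example_velocities, example_offsets)\n",
        "print(example_tap_hits, example_tap_velocities, example_tap_offsets)"
      ],
      "execution_count": null,
      "outputs": []
    },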
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pti3slkJzQHL",
        "colab_type": "text"
      },
      "source": [
        "And a couple of examples using audio:"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ucrlM_djzSQ0",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Audio Taps --> Beats\n",
        "\n",
        "paths = ['clap.wav', 'bbox.wav']\n",
        "temperature = 1.32 #@param {type:\"slider\", min:0.01, max:2.0, step:0.01}\n",
        "velocity_threshold = 0.05 #@param {type:\"slider\", min:0, max:1, step:0.01}\n",
        "stereo = True #@param {type:\"boolean\"}\n",
        "\n",
        "new_beats = []\n",
        "new_drum_audios = []\n",
        "combined_audios = []\n",
        "\n",
        "for f in paths:\n",
        "    print(\"\\n\\n\\nPlaying %s: \" %(f))\n",
        "    y, sr = librosa.load(f)\n",
        "    IPython.display.display(IPython.display.Audio(y, rate=sr))\n",
        "\n",
        "    full_drum_audio, full_tap_audio, tap_and_onsets, drums_and_original, combined_drum_sequence = audio_to_drum(f, velocity_threshold=velocity_threshold, temperature=temperature)\n",
        "    new_beats.append(combined_drum_sequence)\n",
        "    new_drum_audios.append(full_drum_audio)\n",
        "    combined_audios.append(drums_and_original)\n",
        "    print(\"Playing the rhythm detected in %s: \" %(f))\n",
        "    IPython.display.display(IPython.display.Audio(full_tap_audio, rate=sr))\n",
        "    print(\"Playing drums generated from %s: \" %(f))\n",
        "    IPython.display.display(IPython.display.Audio(full_drum_audio, rate=sr))\n",
        "    print(\"Playing %s together with drums\" %(f))\n",
        "    IPython.display.display(IPython.display.Audio(drums_and_original, rate=sr))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9V8jgj2myvmB",
        "colab_type": "text"
      },
      "source": [
        "The model used here only handles 2-measure clips at a constant tempo, so it works best on exactly that: 2 measures of audio starting on a downbeat. Still, it can be fun to try longer clips or non-musical audio and see what happens."
      ]
    },
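    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "In `audio_to_drum`, each window spans 8 beats, i.e. 2 bars of 4/4, so its length in seconds is 8 × 60 / tempo: 4 seconds at 120 BPM and about 5.33 seconds at 90 BPM. The cell below is a simplified sketch that splits a clip into consecutive 2-bar windows at a single constant tempo. `two_bar_windows` is a hypothetical helper; unlike `audio_to_drum`, it does not re-estimate the tempo for each window. As in that loop, the last window may extend past the end of the clip."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def two_bar_windows(clip_length, tempo, beats_per_bar=4, start_time=0.0):\n",
        "  # Hypothetical helper: (start, end) times in seconds of consecutive 2-bar\n",
        "  # windows, assuming one constant tempo (in BPM) for the whole clip.\n",
        "  beat_length = 60.0 / tempo\n",
        "  two_bar_length = beat_length * beats_per_bar * 2\n",
        "  windows = []\n",
        "  t = start_time\n",
        "  while t < clip_length:\n",
        "    windows.append((t, t + two_bar_length))\n",
        "    t += two_bar_length\n",
        "  return windows\n",
        "\n",
        "example_windows = two_bar_windows(clip_length=10.0, tempo=120)\n",
        "print(example_windows)  # [(0.0, 4.0), (4.0, 8.0), (8.0, 12.0)]"
      ],
      "execution_count": null,
      "outputs": []
    },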
    {
      "cell_type": "code",
      "metadata": {
        "id": "cBTsDlWqRzFg",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Upload your own Audio files to Drumify\n",
        "\n",
        "temperature = 1.32 #@param {type:\"slider\", min:0.01, max:2.0, step:0.01}\n",
        "velocity_threshold = 0.05 #@param {type:\"slider\", min:0, max:1, step:0.01}\n",
        "stereo = True #@param {type:\"boolean\"}\n",
        "\n",
        "uploaded = files.upload()\n",
        "\n",
        "new_beats = []\n",
        "new_drum_audios = []\n",
        "combined_audios = []\n",
        "\n",
        "# dict_keys is not indexable in Python 3, so iterate over the file names directly\n",
        "for f in uploaded.keys():\n",
        "    print(\"\\n\\n\\nPlaying %s: \" %(f))\n",
        "    y, sr = librosa.load(f)\n",
        "    IPython.display.display(IPython.display.Audio(y, rate=sr))\n",
        "\n",
        "    full_drum_audio, full_tap_audio, tap_and_onsets, drums_and_original, combined_drum_sequence = audio_to_drum(f, velocity_threshold=velocity_threshold, temperature=temperature)\n",
        "    new_beats.append(combined_drum_sequence)\n",
        "    new_drum_audios.append(full_drum_audio)\n",
        "    combined_audios.append(drums_and_original)\n",
        "    print(\"Playing the rhythm detected in %s: \" %(f))\n",
        "    IPython.display.display(IPython.display.Audio(full_tap_audio, rate=sr))\n",
        "    print(\"Playing drums generated from %s: \" %(f))\n",
        "    IPython.display.display(IPython.display.Audio(full_drum_audio, rate=sr))\n",
        "    print(\"Playing %s together with drums\" %(f))\n",
        "    IPython.display.display(IPython.display.Audio(drums_and_original, rate=sr))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "rLWv48U_h5Kn",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title (Optional) Download your MIDI and Audio files\n",
        "print(\"Creating MIDI files for download...\")\n",
        "for i, beat in enumerate(new_beats):\n",
        "  download(beat, 'drumified_beat_%d.mid' %(i))\n",
        "\n",
        "print(\"Creating Audio files for download. This may take a minute...\")\n",
        "for i, aud in enumerate(new_drum_audios):\n",
        "  download_audio(aud, 'drumified_beat_%d.wav' %(i), sr)\n",
        "  \n",
        "for i, aud in enumerate(combined_audios):\n",
        "  download_audio(aud, 'drumified_%d.wav' %(i), sr)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1rylVpq0vB-3",
        "colab_type": "text"
      },
      "source": [
        "# Transfer a Groove from one Beat to another"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ST5OYtGtFk6y",
        "colab_type": "text"
      },
      "source": [
        "Another fun use of the GrooVAE models is \"Groove Transfer\". Let's load two random beats from our dataset and hear what happens when we combine the \"groove\" from one beat with the drum pattern from the other. Then we'll listen to what it sounds like to move smoothly through, or interpolate across, the \"space\" of possible grooves. A further possibility that allows for easy control is to use a \"tap\" as the source groove."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "sziK6ojUvFUd",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Load Checkpoint\n",
        "config_2bar_transfer = configs.CONFIG_MAP['groovae_2bar_hits_control_tfds']\n",
        "groovae_2bar_transfer = TrainedModel(config_2bar_transfer, 1, checkpoint_dir_or_path=GROOVAE_2BAR_HITS_CONTROL)\n",
        "transfer_converter = config_2bar_transfer.data_converter\n",
        "\n",
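        "def transfer_groove(source_groove, target_beat_controls, model, temperature=1.0):\n",
        "  # Encode the source beat to a latent groove vector, then decode it while\n",
        "  # conditioning on the target beat's controls (its drum hits).\n",
        "  # length=32 covers 2 bars of 16th-note steps.\n",
        "  groove_encoding, _, _ = model.encode([source_groove])\n",
        "  decoded = model.decode(groove_encoding, length=32, temperature=temperature, c_input=target_beat_controls)\n",
        "  return decoded[0]\n",
        "\n",
        "def transfer_groove_encoding(groove_encoding, target_beat_controls, model, temperature=1.0):\n",
        "  # Same as transfer_groove, but starts from an already-computed latent vector.\n",
        "  decoded = model.decode(groove_encoding, length=32, temperature=temperature, c_input=target_beat_controls)\n",
        "  return decoded[0]"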
      ],
      "execution_count": null,
      "outputs": []
    },
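    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The interpolation step below blends two latent vectors with `_slerp`, which is defined elsewhere in this notebook. For reference, here is a standalone NumPy sketch of standard spherical linear interpolation, slerp(p0, p1, t) = [sin((1 - t)Ω) p0 + sin(tΩ) p1] / sin Ω, where Ω is the angle between p0 and p1. It is not guaranteed to match `_slerp` line for line. For endpoints of equal length, every intermediate point keeps that length, which is why slerp is a common choice for walking through VAE latent spaces."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "def slerp(p0, p1, t):\n",
        "  # Spherical linear interpolation between vectors p0 and p1 for t in [0, 1].\n",
        "  p0 = np.asarray(p0, dtype=float)\n",
        "  p1 = np.asarray(p1, dtype=float)\n",
        "  cos_omega = np.dot(p0 / np.linalg.norm(p0), p1 / np.linalg.norm(p1))\n",
        "  omega = np.arccos(np.clip(cos_omega, -1.0, 1.0))  # angle between the vectors\n",
        "  if np.isclose(np.sin(omega), 0.0):\n",
        "    # (anti)parallel vectors: the formula is undefined, fall back to linear interpolation\n",
        "    return (1.0 - t) * p0 + t * p1\n",
        "  return (np.sin((1.0 - t) * omega) * p0 + np.sin(t * omega) * p1) / np.sin(omega)\n",
        "\n",
        "v0, v1 = np.array([1.0, 0.0]), np.array([0.0, 1.0])\n",
        "interp_path = np.array([slerp(v0, v1, t) for t in np.linspace(0, 1, 5)])\n",
        "print(np.linalg.norm(interp_path, axis=1))  # norms stay at 1 for unit-length inputs"
      ],
      "execution_count": null,
      "outputs": []
    },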
    {
      "cell_type": "code",
      "metadata": {
        "id": "ixgKB-adrXfl",
        "colab_type": "code",
        "cellView": "form",
        "colab": {}
      },
      "source": [
        "#@title Transfer a Groove and Interpolate\n",
        "sequence_indices = np.random.randint(0,len(dev_sequences), 2)\n",
        "\n",
        "source_groove = preprocess_2bar(change_tempo(dev_sequences[sequence_indices[0]], 120))\n",
        "target_beat = preprocess_2bar(change_tempo(dev_sequences[sequence_indices[1]], 120))\n",
        "controls = transfer_converter.to_tensors(target_beat).controls[0]\n",
        "\n",
        "new_beat = transfer_groove(source_groove, controls, groovae_2bar_transfer)\n",
        "\n",
        "print(\"Source Groove: \")\n",
        "play(start_notes_at_0(source_groove))\n",
        "print(\"Target Beat: \")\n",
        "play(start_notes_at_0(target_beat))\n",
        "print(\"Transferred: \")\n",
        "play(start_notes_at_0(new_beat))\n",
        "\n",
        "print(\"Interpolating the Groove\")\n",
        "\n",
        "num_steps = 5\n",
        "\n",
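        "# Slerp between the latent means of the target beat and the source groove,\n",
        "# then decode every point with the target beat's controls.\n",
        "_, mu, _ = groovae_2bar_transfer.encode([target_beat, source_groove])\n",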
        "z = np.array([_slerp(mu[0], mu[1], t) for t in np.linspace(0, 1, num_steps)]).squeeze()\n",
        "\n",
        "individual_duration = 4.0  # seconds per interpolated beat: 2 bars at 120 BPM\n",
        "\n",
        "seqs = groovae_2bar_transfer.decode(z, length=32, temperature=1., c_input=controls)\n",
        "\n",
        "interp_seq = note_seq.sequences_lib.concatenate_sequences(\n",
        "      seqs, [individual_duration] * len(seqs))\n",
        "\n",
        "play(start_notes_at_0(interp_seq))"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}
